import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import constant_init, kaiming_init, normal_init, xavier_init
from mmcv.runner import load_checkpoint

from .vgg import VGG
from mmdet.utils import get_root_logger
from ..builder import BACKBONES

class Conv(nn.Module):
    """Standard convolution block: bias-free Conv2d followed by ReLU.

    Args:
        c1 (int): Input channels.
        c2 (int): Output channels.
        k (int): Kernel size.
        s (int): Stride.
        p (int, optional): Explicit padding. When None, defaults to
            ``k // 2`` ("same" padding for odd kernels).
        g (int): Number of conv groups.
        act (bool): Apply ReLU when True, identity otherwise.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        # The original ignored `p` and `act` entirely; honor them here.
        # For the k=1 call sites in this file the derived padding
        # (k // 2 == 0) matches the previous behavior exactly.
        self.conv = nn.Conv2d(c1, c2, k, s,
                              padding=k // 2 if p is None else p,
                              groups=g, bias=False)
        self.act = nn.ReLU(inplace=True) if act else nn.Identity()

    def forward(self, x):
        """Apply the convolution, then the activation."""
        return self.act(self.conv(x))
        # return self.act(self.bn(self.conv(x)))

class SPP(nn.Module):
    """Spatial pyramid pooling layer as used in YOLOv3-SPP.

    Squeezes channels with a 1x1 conv, concatenates the squeezed map
    with max-pooled copies at several kernel sizes (stride 1, "same"
    padding, so spatial size is preserved), then projects the stack
    back to ``c2`` channels with a second 1x1 conv.
    """

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        hidden = c1 // 2  # channel count after the squeeze conv
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList(
            nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2)
            for size in k)

    def forward(self, x):
        """Squeeze, pool at every scale, concatenate, and project."""
        squeezed = self.cv1(x)
        pooled = [pool(squeezed) for pool in self.m]
        return self.cv2(torch.cat([squeezed] + pooled, 1))

@BACKBONES.register_module()
class SSDVGG(VGG):
    """VGG Backbone network for single-shot-detection (BatchNorm variant).

    Every conv appended here is bias-free and followed by BatchNorm +
    ReLU, unlike the stock SSD-VGG which uses biased convs without norm.

    Args:
        input_size (int): width and height of input, from {300, 512}.
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_last_pool (bool): Keep VGG's final pooling layer.
        ceil_mode (bool): ``ceil_mode`` for VGG's pooling layers.
        out_indices (Sequence[int]): Output from which stages.
        out_feature_indices (Sequence[int]): Indices into ``self.features``
            whose outputs are collected. NOTE(review): the default
            (22, 34) presumably matches the BN-variant VGG-16 layout —
            confirm when using other depths.
        l2_norm_scale (float): Initial scale for the L2Norm applied to
            the first collected feature map.

    Example:
        >>> self = SSDVGG(input_size=300, depth=11)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 300, 300)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 19, 19)
        (1, 512, 10, 10)
        (1, 256, 5, 5)
        (1, 256, 3, 3)
        (1, 256, 1, 1)
    """
    # Channel plan of the extra layers per input size; 'S' marks a
    # stride-2 (downsampling) conv whose out channels are the next entry.
    extra_setting = {
        300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
        512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
    }

    def __init__(self,
                 input_size,
                 depth,
                 with_last_pool=False,
                 ceil_mode=True,
                 out_indices=(3, 4),
                 out_feature_indices=(22, 34),
                 l2_norm_scale=20.):
        # TODO: in_channels for mmcv.VGG
        super(SSDVGG, self).__init__(
            depth,
            norm_cfg=dict(type='BN', requires_grad=True),
            with_last_pool=with_last_pool,
            ceil_mode=ceil_mode,
            out_indices=out_indices)
        assert input_size in (300, 512)
        self.input_size = input_size

        # SSD head appended to the VGG body: a stride-1 pool5 plus
        # fc6/fc7-style convs, each as bias-free conv + BN + ReLU.
        self.features.add_module(
            str(len(self.features)),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
        # Dilated 3x3 conv standing in for fc6 (SSD's atrous trick).
        self.features.add_module(
            str(len(self.features)),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6,
                      bias=False))
        self.features.add_module(
            str(len(self.features)), nn.BatchNorm2d(1024))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        # 1x1 conv standing in for fc7.
        self.features.add_module(
            str(len(self.features)),
            nn.Conv2d(1024, 1024, kernel_size=1, bias=False))
        self.features.add_module(
            str(len(self.features)), nn.BatchNorm2d(1024))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.out_feature_indices = out_feature_indices

        self.inplanes = 1024
        self.extra = self._make_extra_layers(self.extra_setting[input_size])
        # L2-normalize the first collected feature map; assumes that
        # index addresses a Conv2d so `out_channels` exists.
        self.l2_norm = L2Norm(
            self.features[out_feature_indices[0]].out_channels,
            l2_norm_scale)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            logger.info('load pretrained weights from %s', pretrained)
            checkpoint = torch.load(pretrained)
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            # Strip the first 9 characters of every key — presumably a
            # 'backbone.' prefix from a full-detector checkpoint; verify
            # against the checkpoints actually in use.
            state_dict = {k[9:]: v for k, v in checkpoint.items()}
            missing_keys, unexpected_keys = self.load_state_dict(
                state_dict, strict=False)
            logger.info('missing keys: %s', missing_keys)
            logger.info('unexpected keys: %s', unexpected_keys)
        elif pretrained is None:
            for m in self.features.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

        # The extra layers are always freshly initialized, even when the
        # backbone itself loaded pretrained weights above.
        for m in self.extra.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)

        constant_init(self.l2_norm, self.l2_norm.scale)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input image batch.

        Returns:
            Tensor | tuple[Tensor]: Multi-level feature maps; a bare
            tensor when only one level is collected.
        """
        outs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.out_feature_indices:
                outs.append(x)
        # Each extra block already ends in ReLU, so the F.relu here is
        # idempotent; outputs are taken after every second extra block.
        for i, layer in enumerate(self.extra):
            x = F.relu(layer(x), inplace=True)
            if i % 2 == 1:
                outs.append(x)
        outs[0] = self.l2_norm(outs[0])
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _make_extra_layers(self, outplanes):
        """Build the extra conv blocks appended after the VGG head.

        Args:
            outplanes (tuple): Channel plan; 'S' marks a stride-2 conv
                whose out channels are the entry that follows it.

        Returns:
            nn.Sequential: Extra blocks, each conv + BN + ReLU.
        """
        layers = []
        kernel_sizes = (1, 3)  # alternate 1x1 and 3x3 kernels
        num_layers = 0
        outplane = None
        for i in range(len(outplanes)):
            if self.inplanes == 'S':
                # The previous entry was 'S'; its channel count (already
                # consumed by the previous conv) is the next in_channels.
                self.inplanes = outplane
                continue
            k = kernel_sizes[num_layers % 2]
            if outplanes[i] == 'S':
                outplane = outplanes[i + 1]
                conv = nn.Sequential(
                    nn.Conv2d(self.inplanes, outplane, k, stride=2,
                              padding=1, bias=False),
                    nn.BatchNorm2d(outplane),
                    nn.ReLU())
            else:
                outplane = outplanes[i]
                conv = nn.Sequential(
                    nn.Conv2d(self.inplanes, outplane, k, stride=1,
                              padding=0, bias=False),
                    nn.BatchNorm2d(outplane),
                    nn.ReLU())
            layers.append(conv)
            self.inplanes = outplanes[i]
            num_layers += 1
        if self.input_size == 512:
            # Extra 4x4 conv for the larger input size.
            conv = nn.Sequential(
                nn.Conv2d(self.inplanes, 256, 4, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU())
            layers.append(conv)
        return nn.Sequential(*layers)


class L2Norm(nn.Module):
    """Channel-wise L2 normalization with a learnable per-channel scale."""

    def __init__(self, n_dims, scale=20., eps=1e-10):
        """L2 normalization layer.

        Args:
            n_dims (int): Number of dimensions to be normalized
            scale (float, optional): Defaults to 20..
            eps (float, optional): Used to avoid division by zero.
                Defaults to 1e-10.
        """
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        # One learnable scale per channel; initialized externally
        # (e.g. via constant_init with `self.scale`).
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        """Normalize each position's channel vector, then rescale.

        Computation is done in FP32 so FP16 training stays numerically
        stable, and the result is cast back to the input dtype.
        """
        x32 = x.float()
        l2 = x32.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
        scale = self.weight.float().view(1, -1, 1, 1).expand_as(x32)
        return (scale * x32 / l2).type_as(x)
